eddc54
@@ -333,11 +333,8 @@
public static MapJoinOperator convertMapJoin(
       }
     }
 
-    RowResolver oldOutputRS = opParseCtxMap.get(op).getRowResolver();
-    RowResolver outputRS = new RowResolver();
-    ArrayList<String> outputColumnNames = new ArrayList<String>();
+    RowResolver outputRS = opParseCtxMap.get(op).getRowResolver();
     Map<Byte, List<ExprNodeDesc>> keyExprMap = new HashMap<Byte, List<ExprNodeDesc>>();
-    Map<Byte, List<ExprNodeDesc>> valueExprMap = new HashMap<Byte, List<ExprNodeDesc>>();
 
     // Walk over all the sources (which are guaranteed to be reduce sink
     // operators).
@@ -349,7 +346,6 @@
public static MapJoinOperator convertMapJoin(
       new ArrayList<Operator<? extends OperatorDesc>>();
     List<Operator<? extends OperatorDesc>> oldReduceSinkParentOps =
        new ArrayList<Operator<? extends OperatorDesc>>();
-    Map<String, ExprNodeDesc> colExprMap = new HashMap<String, ExprNodeDesc>();
 
     // found a source which is not to be stored in memory
     if (leftSrc != null) {
@@ -385,37 +381,34 @@
public static MapJoinOperator convertMapJoin(
       keyExprMap.put(pos, keys);
     }
 
-    // create the map-join operator
-    for (pos = 0; pos < newParentOps.size(); pos++) {
-      RowResolver inputRS = opParseCtxMap.get(newParentOps.get(pos)).getRowResolver();
-      List<ExprNodeDesc> values = new ArrayList<ExprNodeDesc>();
-
-      Iterator<String> keysIter = inputRS.getTableNames().iterator();
-      while (keysIter.hasNext()) {
-        String key = keysIter.next();
-        HashMap<String, ColumnInfo> rrMap = inputRS.getFieldMap(key);
-        Iterator<String> fNamesIter = rrMap.keySet().iterator();
-        while (fNamesIter.hasNext()) {
-          String field = fNamesIter.next();
-          ColumnInfo valueInfo = inputRS.get(key, field);
-          ColumnInfo oldValueInfo = oldOutputRS.get(key, field);
-          if (oldValueInfo == null) {
-            continue;
-          }
-          String outputCol = oldValueInfo.getInternalName();
-          if (outputRS.get(key, field) == null) {
-            outputColumnNames.add(outputCol);
-            ExprNodeDesc colDesc = new ExprNodeColumnDesc(valueInfo.getType(), valueInfo
-                .getInternalName(), valueInfo.getTabAlias(), valueInfo.getIsVirtualCol());
-            values.add(colDesc);
-            outputRS.put(key, field, new ColumnInfo(outputCol, valueInfo.getType(), valueInfo
-                .getTabAlias(), valueInfo.getIsVirtualCol(), valueInfo.isHiddenVirtualCol()));
-            colExprMap.put(outputCol, colDesc);
-          }
+    // removing RS, only ExprNodeDesc is changed (key/value/filter exprs and colExprMap)
+    // others (output column-name, RR, schema) remain intact
+    Map<String, ExprNodeDesc> colExprMap = op.getColumnExprMap();
+    List<String> outputColumnNames = op.getConf().getOutputColumnNames();
+
+    List<ColumnInfo> schema = new ArrayList<ColumnInfo>(op.getSchema().getSignature());
+
+    Map<Byte, List<ExprNodeDesc>> valueExprs = op.getConf().getExprs();
+    Map<Byte, List<ExprNodeDesc>> newValueExprs = new HashMap<Byte, List<ExprNodeDesc>>();
+    for (Map.Entry<Byte, List<ExprNodeDesc>> entry : valueExprs.entrySet()) {
+      byte tag = entry.getKey();
+      Operator<?> terminal = oldReduceSinkParentOps.get(tag);
+
+      List<ExprNodeDesc> values = entry.getValue();
+      List<ExprNodeDesc> newValues = ExprNodeDescUtils.backtrack(values, op, terminal);
+      newValueExprs.put(tag, newValues);
+      for (int i = 0; i < schema.size(); i++) {
+        ColumnInfo column = schema.get(i);
+        if (column == null) {
+          continue;
+        }
+        ExprNodeDesc expr = colExprMap.get(column.getInternalName());
+        int index = ExprNodeDescUtils.indexOf(expr, values);
+        if (index >= 0) {
+          colExprMap.put(column.getInternalName(), newValues.get(index));
+          schema.set(i, null);
         }
       }
-
-      valueExprMap.put(Byte.valueOf((byte) pos), values);
     }
 
     Map<Byte, List<ExprNodeDesc>> filters = desc.getFilters();
@@ -456,7 +449,7 @@
public static MapJoinOperator convertMapJoin(
 
     int[][] filterMap = desc.getFilterMap();
     for (pos = 0; pos < newParentOps.size(); pos++) {
-      List<ExprNodeDesc> valueCols = valueExprMap.get(Byte.valueOf((byte) pos));
+      List<ExprNodeDesc> valueCols = newValueExprs.get(pos);
       int length = valueCols.size();
       List<ExprNodeDesc> valueFilteredCols = new ArrayList<ExprNodeDesc>(length);
       // deep copy expr node desc
@@ -492,7 +485,7 @@
public static MapJoinOperator convertMapJoin(
     } else {
       dumpFilePrefix = "mapfile"+PlanUtils.getCountForMapJoinDumpFilePrefix();
     }
-    MapJoinDesc mapJoinDescriptor = new MapJoinDesc(keyExprMap, keyTableDesc, valueExprMap,
+    MapJoinDesc mapJoinDescriptor = new MapJoinDesc(keyExprMap, keyTableDesc, newValueExprs,
         valueTableDescs, valueFiltedTableDescs, outputColumnNames, mapJoinPos, joinCondns,
         filters, op.getConf().getNoOuterJoin(), dumpFilePrefix);
     mapJoinDescriptor.setTagOrder(tagOrder);
